feat: Expose prediction SSE streams#3019
Conversation
There was a problem hiding this comment.
I'm Bonk, and I've done a quick review of your PR.
This PR adds Server-Sent Events (SSE) streaming for predictions with event replay for late subscribers. The implementation looks solid overall with good test coverage. I found one potential resource concern that should be addressed:
Issues
1. Unbounded memory growth in stream_history (Medium)
The stream_history Vec in Prediction grows unbounded during prediction execution. For long-running predictions with frequent output chunks, logs, or metrics, this could cause memory exhaustion. Consider adding a limit to the history size or making it configurable.
|
I've completed my review of PR #3019. The PR adds SSE streaming for predictions with event replay functionality. The implementation is well-structured and includes comprehensive tests. Summary of findings: I identified one medium-severity issue regarding unbounded memory growth in the The rest of the implementation looks solid:
|
|
Based on my review of PR #3019, I agree with the previous reviewer's finding about unbounded memory growth in The implementation is well-structured:
LGTM |
cfa865e to
c61e323
Compare
|
Based on my thorough review of PR #3019, I have analyzed all the changed files: Summary of changes: Previous findings: My independent assessment:
No additional actionable issues beyond what was already flagged. LGTM |
|
After reviewing PR #3019, I've analyzed the changes to add Server-Sent Events (SSE) streaming for predictions. The implementation reuses Key aspects reviewed:
Previously flagged issue: No additional actionable issues found beyond what was already flagged. LGTM |
|
Based on my thorough review of PR #3019, I can see this PR adds Server-Sent Events (SSE) streaming for predictions by reusing Key changes reviewed:
Previously flagged issue (addressed): Additional verification:
LGTM |
|
Based on my review of PR #3019, I can confirm that the previously flagged issue regarding unbounded memory growth in Verified fixes:
Additional verification:
The implementation is well-structured with comprehensive test coverage including unit tests for lagged streams, truncated replay, and integration tests. LGTM |
michaeldwan
left a comment
There was a problem hiding this comment.
Solid design. Main point discussed offline: streaming needs to be opt-in at the handler level via a decorator/attribute. Five concrete changes in the comments.
|
|
||
| fn emit_stream_event(&mut self, event: PredictionStreamEvent) { | ||
| if self.stream_history.len() == MAX_STREAM_HISTORY_EVENTS { | ||
| self.stream_history.remove(0); |
There was a problem hiding this comment.
Vec::remove(0) is O(n) on every event after the buffer fills -- thousands of shifts per prediction for token-by-token LLM output. VecDeque::pop_front() is O(1).
| prediction_id.clone(), | ||
| input.clone(), | ||
| webhook_sender, | ||
| response_mode != PredictionResponseMode::AsyncJson, |
There was a problem hiding this comment.
response_mode != PredictionResponseMode::AsyncJson is true for SyncJson too. Should be response_mode == PredictionResponseMode::AsyncSse -- the current expression is inert for sync mode but reads as a bug.
|
|
||
| impl Prediction { | ||
| pub fn new(id: String, webhook: Option<WebhookSender>) -> Self { | ||
| let (stream_tx, _) = tokio::sync::broadcast::channel(1024); |
There was a problem hiding this comment.
Channel capacity and history cap are both 1024 but defined independently. Use a shared constant so they can't drift.
| return; | ||
| } | ||
|
|
||
| if self.service.stream_receiver_count(&self.id) == 0 |
There was a problem hiding this comment.
The async cleanup task calls remove_prediction while the SSE client may still be draining. After removal, stream_receiver_count returns unwrap_or(0) and prediction_is_terminal returns unwrap_or(true) -- which happen to be the safe defaults. That invariant needs a comment, or hold an Arc to the entry in the guard so it doesn't depend on post-removal lookup.
| cog serve --upload-url http://unused/ | ||
|
|
||
| curl -H Accept:text/event-stream PUT /predictions/sse-stream-test '{"id":"sse-stream-test","input":{}}' | ||
| stdout 'event: output' |
There was a problem hiding this comment.
Missing stdout 'event: start' -- should verify the full lifecycle.
There was a problem hiding this comment.
I'm Bonk, and I've done a quick review of your PR.
This PR adds Server-Sent Events (SSE) streaming for predictions with event replay for late subscribers. The implementation is well-structured overall.
Summary
I've verified the issues flagged by previous reviewers:
- Vec::remove(0) O(n) issue - Fixed: Now uses
VecDeque::pop_front()which is O(1) - Unbounded memory growth - Fixed:
STREAM_EVENT_BUFFER_CAPACITY(1024) limits both broadcast channel and history - Integration test coverage - The test at
integration-tests/tests/sse_streaming_output.txtarincludesstdout 'event: start'as suggested
The implementation correctly:
- Distinguishes sync vs async cancellation semantics via
PredictionStreamGuard - Handles late subscribers with bounded replay (1024 events)
- Includes 15s keep-alive for SSE streams
- Has comprehensive test coverage including lagged stream and truncated replay scenarios
LGTM - all previously flagged issues have been addressed.
|
Posted review for PR #3019. All previously flagged issues have been addressed:
|
|
Based on my review of PR #3019, I've verified that all previously flagged issues have been addressed: Verified fixes:
Implementation highlights:
LGTM |
|
LGTM |
|
Based on my thorough review of PR #3019, I've verified that all previously flagged issues have been addressed: Verified fixes:
Implementation highlights:
LGTM |
|
@michaeldwan decorator added in 8c9c982 |
michaeldwan
left a comment
There was a problem hiding this comment.
Summary
Solid design. The broadcast channel + bounded replay architecture is sound, the opt-in gating via @cog.streaming is the right call, and the integration between Go parser → OpenAPI extension → coglet runtime is clean. The replay/live transition has no race (mutex held atomically), the bounded history buffer has no off-by-one, and completed events are emitted on all terminal paths.
One blocker, several should-fixes. Details below and inline.
Blockers
-
examples/streaming-text/predict.pydoesn't use@cog.streaming-- the README tells users to curl withAccept: text/event-stream, which will 406. The integration testsse_streaming_output.txtarcorrectly uses@streaming, but the shipped example doesn't. (See inline comment.) -
No test for SSE events during failed or cancelled predictions.
set_failed()andset_canceled()both emitCompletedevents, but no test verifies the SSE stream deliversevent: completedwith"status":"failed"/"status":"canceled". These are critical user-facing paths.
Should-fix
decoratorIsCogStreaminghard-codes"cog.streaming"instead of resolving throughImportContext.import cog as c→@c.streamingwon't be detected. The rest of the parser handles aliases. (See inline comment.)@cog.streaming()with parens silently degrades. Parser rejects call form, but Python decorator works either way. User gets a working model where SSE returns 406 with no hint about the parens.- No limit on SSE subscriber count per prediction.
subscribe_prediction_stream()creates a new broadcast receiver with no cap. Repeated SSE connections to the same prediction ID amplify memory pressure. - Orphaned
pending_cancellationsleak memory. Cancel messages arriving after a prediction completes get stored in theHashSetand never cleaned up. (See inline comment.) - Double-clone on every stream event.
emit_stream_eventclonesserde_json::Valuefor history storage.Arc<PredictionStreamEvent>would eliminate deep clones -- history and broadcast share the same allocation. Also fixes the O(n) deep-clone insubscribe_stream_replay()under the mutex. (See inline comment.) PredictionStreamGuard::Dropcallstokio::spawnviacancel(). If dropped outside a tokio runtime context,tokio::spawnpanics. UseHandle::try_current().- No test for concurrent SSE subscribers. The guard checks
stream_receiver_count() == 0before cancelling, but no test verifies dropping one of two subscribers doesn't cancel. - Training endpoints silently ignore
Accept: text/event-stream. Returning 406 or documenting would be more honest than silent fallback to JSON.
Nits
RegisterPredictionMessage4-element tuple → named struct.streaminglisted under# Metricsin__all__-- it's a decorator.- Module-level
FTypeVar →_Fto signal internal. replay.into()creates unnecessaryVecDequefromVec.id.to_string()allocated twice insubscribe_prediction_stream.- Missing
require.NotNilguard before type assertions in streaming OpenAPI tests. - Broadcast channel capacity and history buffer both use the same 1024 constant by coincidence -- give them separate named constants.
Verified correct
- Replay + live transition (no race -- mutex held atomically during subscribe + snapshot)
- Bounded history buffer (no off-by-one)
completedevent emitted on all terminal paths- Terminal state guards prevent double-completion
findTargetFunctionreturningdecorated_definitionhandled correctly viaUnwrapFunctioncog predictCLI works fine with streaming models (uses sync JSON path)docs/python.mdanddocs/llms.txtare accurate and in sync
| ), | ||
| ) -> Iterator[str]: | ||
| messages = [{"role": "user", "content": prompt}] | ||
| text = self.tokenizer.apply_chat_template( |
There was a problem hiding this comment.
Blocker: This method is missing @cog.streaming. The README (line 28-31) tells users to curl with Accept: text/event-stream, which will return 406 since the model doesn't opt in.
The integration test sse_streaming_output.txtar correctly uses @streaming, but this shipped example doesn't.
from cog import BasePredictor, Input, streaming
class Predictor(BasePredictor):
# ...
@streaming
def predict(self, ...) -> Iterator[str]:| self.stream_history.pop_front(); | ||
| self.stream_history_skipped += 1; | ||
| } | ||
| self.stream_history.push_back(event.clone()); |
There was a problem hiding this comment.
Should-fix: This clones the event (containing serde_json::Value) for history, then moves the original into broadcast::send. For high-throughput models yielding many chunks, this deep-clones arbitrarily large JSON on every output.
Consider Arc<PredictionStreamEvent> for the broadcast channel type -- history and broadcast share the same allocation, and subscribe_stream_replay() becomes 1024 atomic increments instead of 1024 deep JSON clones under the prediction mutex.
stream_tx: broadcast::Sender<Arc<PredictionStreamEvent>>,
stream_history: VecDeque<Arc<PredictionStreamEvent>>,
fn emit_stream_event(&mut self, event: PredictionStreamEvent) {
// ...
let event = Arc::new(event);
self.stream_history.push_back(Arc::clone(&event));
let _ = self.stream_tx.send(event);
}| None => { | ||
| tracing::debug!(%prediction_id, "Cancel requested for unknown prediction (may have already completed)"); | ||
| tracing::debug!(%prediction_id, "Cancel requested for unknown prediction; storing pending cancellation"); | ||
| pending_cancellations.insert(prediction_id); |
There was a problem hiding this comment.
Should-fix: If the cancel arrives after the prediction has already completed and been removed from predictions, the ID is stored here and never consumed. In a long-running server with many cancelled predictions, this is an unbounded leak.
Consider adding a size cap (e.g., 1000 entries) or a TTL, and log a warning when it's exceeded.
| func decoratorIsCogStreaming(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { | ||
| for _, child := range NamedChildren(node) { | ||
| switch child.Type() { | ||
| case "attribute": |
There was a problem hiding this comment.
Should-fix: Hard-coded string match. import cog as c then @c.streaming won't be detected -- the content will be "c.streaming", not "cog.streaming".
The rest of the parser resolves aliases through ImportContext (e.g., IsBaseModel, IsOpaque). This should do the same:
case "attribute":
text := Content(child, source)
parts := strings.SplitN(text, ".", 2)
if len(parts) != 2 || parts[1] != "streaming" {
return false
}
entry, ok := imports.Names.Get(parts[0])
return ok && entry.Module == "cog" && entry.Original == "cog"| entry, ok := imports.Names.Get("streaming") | ||
| return ok && entry.Module == "cog" && entry.Original == "streaming" | ||
| case "call": | ||
| return false |
There was a problem hiding this comment.
Should-fix: This rejects @cog.streaming() (call form), but the Python decorator works fine with either @streaming or @streaming(). A user who writes @cog.streaming() gets a model that builds, runs, and yields output -- but SSE returns 406 with no hint about the parentheses.
Either support the call form here (check if the callee is cog.streaming or imported streaming), or make the Python decorator raise a clear error when called with parens. The current behavior is a silent gotcha.
| self: &Arc<Self>, | ||
| id: &str, | ||
| ) -> Option<PredictionStreamSubscription> { | ||
| let entry = self.predictions.get(id)?; |
There was a problem hiding this comment.
Should-fix: No cap on subscriber count. The idempotent PUT endpoint allows repeated SSE connections to the same prediction ID, each creating a new broadcast receiver. An attacker opening many connections forces the sender to retain events for slow consumers.
Consider:
let pred = entry.prediction.lock().ok()?;
if pred.stream_receiver_count() >= MAX_STREAM_SUBSCRIBERS {
return None;
}| // Prediction cleanup may remove the service entry before the SSE response | ||
| // finishes draining. Missing entries deliberately report zero receivers and | ||
| // terminal state so this guard cannot cancel an already-cleaned prediction. | ||
| if self.service.stream_receiver_count(&self.id) == 0 |
There was a problem hiding this comment.
Should-fix: cancel() internally calls tokio::spawn. If this guard is dropped outside a tokio runtime context (panic unwinding, runtime shutdown), tokio::spawn panics.
Safer:
if let Ok(handle) = tokio::runtime::Handle::try_current() {
let service = Arc::clone(&self.service);
let id = self.id.clone();
handle.spawn(async move { service.cancel(&id); });
}Or refactor cancel() itself to use Handle::try_current() before spawning.
| "CancelationException", | ||
| # Metrics | ||
| "current_scope", | ||
| "streaming", |
There was a problem hiding this comment.
Nit: streaming is a decorator, not a metric. Move it to its own # Decorators section or under # Core classes.
| URLPath, | ||
| ) | ||
|
|
||
| F = TypeVar("F", bound=Callable[..., object]) |
There was a problem hiding this comment.
Nit: This is importable as from cog import F. Prefix with underscore (_F) to signal it's internal.
| } | ||
|
|
||
| type RegisterPredictionMessage = ( | ||
| SlotId, |
There was a problem hiding this comment.
Nit: A 4-element tuple type alias is opaque -- the ack channel's purpose is invisible at usage sites. A named struct would be clearer:
struct RegisterPredictionMessage {
slot_id: SlotId,
prediction: Arc<StdMutex<Prediction>>,
idle_sender: oneshot::Sender<SlotIdleToken>,
registered_ack: oneshot::Sender<()>,
}|
@copilot resolve the merge conflicts in this pull request |
Co-authored-by: markphelps <209477+markphelps@users.noreply.github.com>
|
@markphelps Bonk workflow was cancelled. View workflow run · To retry, trigger Bonk again. |
|
/bonk review latest commits |
|
Based on my thorough review of the latest commits in PR #3019, here's my assessment: SummaryAll previously flagged review issues have been addressed. The implementation is solid and well-tested. Verified Fixes from Previous ReviewsBlockers (all resolved):
Should-fix items (all resolved):
Nits (all resolved):
Notable Implementation Details
Minor NoteThe github-code-quality bot flagged the LGTM - This PR is ready to merge. All review feedback has been addressed comprehensively. |
|
i'm interested in working on WebSocket support after this lands, so that we can have a slightly-more-efficient means to stream responses in and out of the container (with not much idea of the limits of this until i actually begin it). |
michaeldwan
left a comment
There was a problem hiding this comment.
Previous review feedback is fully addressed — every blocker and should-fix from 2026-05-19 is in with regression tests. The architecture (broadcast + bounded replay + Arc<PredictionStreamEvent> + receiver-before-guard drop order) holds up under a fresh pass.
One regression to fix before merge: a decorated train() method in a class now fails to parse (inline). The should-fixes below are real but none are correctness-critical for the SSE happy path.
Blocker
findMethodInClassreturnsdecorated_definition— breaks decoratedtrainmethods. Reproduced locally withErrParse: function has no parameters node. Inline.
Should-fix
- SSE subscription failure leaves prediction running unobserved (inline on
routes.rs) - Drop-order invariant in
StreamState/PredictionStreamSubscriptionis load-bearing but undocumented (inline) - No test for train mode +
SupportsStreaming=true—openapi.go:220guards with!isTrain && SupportsStreaming, no coverage for the omit path (inline onopenapi_test.go) @streamingon a non-iterator return type silently produces broken SSE — the parser only inspects the decorator, never the return annotation (inline onparser.go)examples/streaming-text/predict.pyuses deprecatedBasePredictor/predict()even thoughdocs/python.mddirects new code toBaseRunner/run()(inline)__cog_streaming__attribute is dead code — nothing reads it, but two unit tests assert it (inline)
Nits
PredictionService::cancelsilently drops worker-side cancel ontry_read()contention (inline)subscribe_prediction_streammaps poisoned mutex toNotFound→ misleading 404 (inline)- Aliased
from cog import streaming as streamthen@streamis not detected; attribute form already handles aliasing — asymmetric (inline) - Example's
thread.join()is unreachable on early consumer exit → GPU memory leak in the copy-paste case (inline) - Docs list four spellings as if the user has a meaningful choice (inline on
docs/python.md)
| nameNode := funcNode.ChildByFieldName("name") | ||
| if nameNode != nil && Content(nameNode, source) == methodName { | ||
| return funcNode, nil | ||
| return child, nil |
There was a problem hiding this comment.
Blocker: Returning child here breaks decorated train() methods in classes. ParsePredictor (line 80) takes target.node directly without unwrapping, then line 88 calls funcNode.ChildByFieldName("parameters") — that field exists on function_definition, not decorated_definition, so any decorated train method fails to parse.
Reproduced locally:
class Trainer(BasePredictor):
@functools.wraps(lambda x: x)
def train(self, n: int) -> Path: ...→ ErrParse: function has no parameters node.
functionSupportsStreaming already handles a function_definition argument by walking node.Parent() (parser.go:675-681), so returning funcNode here is correct and streaming detection still works.
Fix:
return funcNode, nilPlus a regression test for a decorated train method — that gap is why CI missed this.
| } else { | ||
| None | ||
| }; | ||
|
|
There was a problem hiding this comment.
Should-fix: The predict task is spawned before checking whether subscribe_prediction_stream returned Ok. If subscription fails (TooManySubscribers on burst, NotFound on a cleanup race), the prediction still runs to completion, fires webhooks, consumes a slot, and is cleaned up with no client watching.
Either subscribe before spawn and gate the spawn on Ok, or cancel-and-await on subscription failure.
| skipped: u64, | ||
| receiver: tokio::sync::broadcast::Receiver<SharedPredictionStreamEvent>, | ||
| guard: PredictionStreamGuard, | ||
| } |
There was a problem hiding this comment.
Should-fix: Drop order is load-bearing here: receiver must drop before guard, otherwise stream_receiver_count(&id) in the guard's Drop still sees this receiver (≥1) and cancel_on_stream_drop never fires. A future field reorder silently breaks SSE cancel-on-disconnect.
The dropping_one_of_two_sync_stream_subscriptions test would catch it, but a comment is what protects refactors that don't run the full suite:
// IMPORTANT: drop order matters — receiver must drop before guard so
// stream_receiver_count() returns 0 by the time the guard runs cleanup.Same comment belongs on StreamState in routes.rs:681-687.
| require.NotNil(t, postPath) | ||
| post := postPath.(map[string]any) | ||
| assert.Equal(t, true, post["x-cog-streaming"]) | ||
| } |
There was a problem hiding this comment.
Should-fix: openapi.go:220 guards x-cog-streaming with !isTrain && info.SupportsStreaming, but no test exercises isTrain=true && SupportsStreaming=true to confirm the extension is omitted on /trainings. Easy to regress if someone refactors the guard. Mirror this test with Mode: ModeTrain and assert the extension is absent.
| return nil | ||
| } | ||
|
|
||
| func functionSupportsStreaming(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { |
There was a problem hiding this comment.
Should-fix: functionSupportsStreaming only inspects the decorator, never the return type. A user who applies @streaming to def predict(self) -> str gets a schema claiming streaming support; SSE clients receive a single chunk + completed and conclude the model is broken.
Enforce here: require an Iterator[...] / AsyncIterator[...] / *ConcatenateIterator[...] return annotation before reporting SupportsStreaming = true. The return type is already parsed at line 108-115 — this is the natural enforcement point. Document the constraint in docs/python.md.
| ); | ||
| } | ||
| }); | ||
| spawn_orchestrator_cancel(orch, id_owned); |
There was a problem hiding this comment.
Nit: If self.orchestrator.try_read() can't acquire the lock, the worker-side cancel is silently dropped — the CancellationToken fires but the subprocess never gets ControlRequest::Cancel. Today fine (only one .write() at startup), but a tracing::warn!(prediction_id = %id, "Skipped worker cancel: orchestrator lock unavailable") would make any future regression diagnosable.
| let prediction = entry | ||
| .prediction | ||
| .lock() | ||
| .map_err(|_| SubscribePredictionStreamError::NotFound)?; |
There was a problem hiding this comment.
Nit: Poisoned mutex maps to SubscribePredictionStreamError::NotFound → returns 404. A 404 for a poisoned mutex is misleading at the diagnostic level. Either use try_lock_prediction (which attempts recovery and fails the prediction explicitly) or add a dedicated error variant.
| return false | ||
| } | ||
| entry, ok := imports.Names.Get("streaming") | ||
| return ok && entry.Module == "cog" && entry.Original == "streaming" |
There was a problem hiding this comment.
Nit: from cog import streaming as stream then @stream is not detected — this rejects anything where the local identifier isn't literally "streaming". The attribute form (line 720-727) already handles aliasing via entry.Original == "cog", so the asymmetry is mildly surprising.
One-line fix mirrors the attribute form:
func identifierIsCogStreaming(node *sitter.Node, source []byte, imports *schema.ImportContext) bool {
entry, ok := imports.Names.Get(Content(node, source))
return ok && entry.Module == "cog" && entry.Original == "streaming"
}| if chunk: | ||
| yield chunk | ||
|
|
||
| thread.join() |
There was a problem hiding this comment.
Nit: thread.join() is unreachable if the consumer of the Iterator[str] (the coglet worker) stops iterating early — e.g. prediction cancellation. The for chunk in streamer: loop exits via GeneratorExit, the join is skipped, and the generation thread keeps running until max_new_tokens completes, holding GPU memory.
Best practice for an example users will copy:
try:
for chunk in streamer:
if chunk:
yield chunk
finally:
thread.join()(or a comment noting the limitation.)
|
|
||
| To support streaming output in your Cog model, add `from typing import Iterator` to your `run.py` file. The `typing` package is a part of Python's standard library so it doesn't need to be installed. Then add a return type annotation to the `run()` method in the form `-> Iterator[<type>]` where `<type>` can be one of `str`, `int`, `float`, `bool`, or `cog.Path`. | ||
|
|
||
| To allow clients to receive chunks as server-sent events with `Accept: text/event-stream`, decorate the prediction method (`run()` or `predict()`) with `@cog.streaming`, `@cog.streaming()`, `@streaming`, or `@streaming()` imported from `cog`. Without the decorator, iterator outputs still work in normal JSON responses, but SSE requests return `406 Not Acceptable`. |
There was a problem hiding this comment.
Nit: Listing all four spellings (@cog.streaming, @cog.streaming(), @streaming, @streaming()) reads like the user has a meaningful choice between four things. Consider: "decorate with @cog.streaming (or @streaming if imported directly); the parenthesized form @streaming() is also accepted." Same change in docs/llms.txt:1982.
|
@copilot resolve the merge conflicts in this pull request |
Co-authored-by: markphelps <209477+markphelps@users.noreply.github.com>
|
Based on my thorough review of PR #3019, I can see that all the previously flagged issues from the review history have been comprehensively addressed: Summary of Verified FixesAll blockers and should-fix items from michaeldwan's 2026-05-19 review are resolved:
The implementation is solid with comprehensive test coverage and all review feedback has been addressed. LGTM |
|
please @michaeldwan |
Summary
POST /predictionsandPUT /predictions/{id}for Server-Sent Events by returning an SSE stream when requests sendAccept: text/event-stream.@streaming/@cog.streamingopt-in while preserving existing iterator JSON behavior for non-SSE requests.Prefer: respond-asyncwithout SSE still returns202JSON.Streaming behavior
start,output,log,metric, and terminalcompletedevents.PUT /predictions/{id}andAccept: text/event-stream.COG_STREAM_HISTORY_CAPACITYto configure per-prediction SSE replay history;0disables replay while keeping live streaming enabled.406 Not Acceptablefor SSE requests when a model has not opted in, and for training endpoints which do not support SSE.CLI behavior
cog predict --streamoption or otherwise change the predict CLI.Docs and tests
docs/environment.mdfor public Cog-specific environment variables, including the new SSE replay history setting.